class Solution {
public:
    // Disjoint-set forest over graph nodes: p[x] is x's parent, roots satisfy p[x] == x.
    vector<int> p;
    // size[r] counts the nodes in the set rooted at r (only meaningful for roots).
    vector<int> size;

    /// Picks one node from `initial` to remove from the graph.
    ///
    /// Nodes not listed in `initial` are grouped into connected components using
    /// only edges between such nodes. For every initial node, the score is the total
    /// size of the components that it, and no other initial node, is directly
    /// adjacent to. Returns the initial node with the highest score; ties go to the
    /// smallest node index. Returns 0 when `initial` is empty.
    ///
    /// @param graph  n x n adjacency matrix; an edge is an entry equal to 1.
    /// @param initial  indices of initially infected nodes (read only).
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        const int nodeCount = graph.size();

        // Reset the union-find state so the object can be reused across calls.
        p.resize(nodeCount);
        for (int v = 0; v < nodeCount; ++v) p[v] = v;
        size.assign(nodeCount, 1);

        vector<char> infected(nodeCount, 0);
        for (int node : initial) infected[node] = 1;

        // Union every pair of adjacent uninfected nodes (upper triangle only).
        for (int u = 0; u < nodeCount; ++u) {
            if (infected[u]) continue;
            for (int v = u + 1; v < nodeCount; ++v) {
                if (!infected[v] && graph[u][v] == 1) merge(u, v);
            }
        }

        // For each initial node, gather the distinct clean components it touches,
        // and count how many initial nodes touch each component root.
        const int k = initial.size();
        vector<vector<int>> touched(k);
        vector<int> touchCount(nodeCount, 0);
        for (int idx = 0; idx < k; ++idx) {
            const int src = initial[idx];
            vector<int>& roots = touched[idx];
            for (int v = 0; v < nodeCount; ++v) {
                if (!infected[v] && graph[src][v] == 1) roots.push_back(find(v));
            }
            sort(roots.begin(), roots.end());
            roots.erase(unique(roots.begin(), roots.end()), roots.end());
            for (int r : roots) touchCount[r] += 1;
        }

        // A component is saved by removing `src` only if `src` is its sole infected neighbour.
        int best = 0;
        int bestSaved = -1;
        for (int idx = 0; idx < k; ++idx) {
            const int src = initial[idx];
            int saved = 0;
            for (int r : touched[idx]) {
                if (touchCount[r] == 1) saved += size[r];
            }
            const bool better = saved > bestSaved;
            const bool tieSmaller = saved == bestSaved && src < best;
            if (better || tieSmaller) {
                bestSaved = saved;
                best = src;
            }
        }
        return best;
    }

    /// Returns the root of x's set, pointing every node on the path directly at it.
    int find(int x) {
        int root = x;
        while (p[root] != root) root = p[root];
        while (p[x] != root) {
            const int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }

    /// Joins the sets of a and b, attaching a's root under b's root.
    void merge(int a, int b) {
        const int rootA = find(a);
        const int rootB = find(b);
        if (rootA == rootB) return;
        size[rootB] += size[rootA];
        p[rootA] = rootB;
    }
};